import argparse
import json
import os
import sys
from math import sqrt
from typing import Dict, List, Tuple

import numpy as np
from scipy import stats
from sklearn.metrics import mean_squared_error, r2_score


def pearsonr_ci(
    x: np.ndarray, y: np.ndarray, alpha: float = 0.05
) -> Tuple[float, float, float, float]:
    """Calculate the Pearson correlation with a Fisher-z confidence interval.

    Args:
        x: First sample.
        y: Second sample, paired with ``x``.
        alpha: Significance level; the interval covers ``1 - alpha``.

    Returns:
        A tuple ``(r, p, lo, hi)`` where ``r`` and ``p`` are the statistic and
        p-value returned by ``scipy.stats.pearsonr`` and ``lo``/``hi`` are the
        interval bounds. The bounds are NaN when the interval is undefined
        (three or fewer observations, or an undefined correlation).
    """
    r, p = stats.pearsonr(x, y)
    n = len(x)

    # The Fisher standard error is 1 / sqrt(n - 3): it is infinite at n == 3
    # and imaginary below that, so no meaningful interval exists.
    if n <= 3 or np.isnan(r):
        return r, p, np.nan, np.nan

    # arctanh(+-1) is infinite; the transformed interval collapses onto r.
    if abs(r) >= 1.0:
        return r, p, r, r

    r_z = np.arctanh(r)
    se = 1 / np.sqrt(n - 3)
    z = stats.norm.ppf(1 - alpha / 2)
    lo, hi = np.tanh((r_z - z * se, r_z + z * se))
    return r, p, lo, hi


def extract_persona_predictions(
    persona_responses: Dict, min_score: float = 0.0, max_score: float = 10.0
) -> List[float]:
    """Extract valid numeric predictions from a dictionary of persona responses.

    Each value of ``persona_responses`` may be a bare number or a dict. For a
    dict, the first non-None value among the keys ``weighted_prediction``,
    ``mean_prediction`` and ``prediction`` (in that order) is used. Values
    that cannot be converted to ``float`` or fall outside the accepted range
    are dropped.

    Args:
        persona_responses: Mapping of persona identifier to response; may be
            empty or None.
        min_score: Smallest accepted prediction (inclusive).
        max_score: Largest accepted prediction (inclusive).

    Returns:
        The accepted predictions as floats, in iteration order of the mapping.
    """
    predictions = []
    if not persona_responses:
        return predictions

    candidate_keys = ("weighted_prediction", "mean_prediction", "prediction")
    for response in persona_responses.values():
        if isinstance(response, dict):
            pred = next(
                (
                    response[key]
                    for key in candidate_keys
                    if response.get(key) is not None
                ),
                None,
            )
        elif isinstance(response, (int, float)):
            pred = response
        else:
            pred = None

        # bool is a subclass of int; a JSON true/false is not a score and must
        # not be silently counted as 1.0 / 0.0.
        if pred is None or isinstance(pred, bool):
            continue

        try:
            value = float(pred)
        except (ValueError, TypeError, OverflowError):
            continue

        # NaN fails both comparisons and is therefore rejected here as well.
        if min_score <= value <= max_score:
            predictions.append(value)
    return predictions


def calculate_subset_metrics(gt_list: List[float], pred_list: List[float]) -> Dict:
    """Calculate a dictionary of metrics for a given subset of data.

    Args:
        gt_list: Ground-truth scores.
        pred_list: Predicted scores, paired with ``gt_list``.

    Returns:
        An empty dict when ``gt_list`` is empty. Otherwise a dict with the
        keys ``correlation``, ``correlation_p_value``, ``correlation_ci_lower``,
        ``correlation_ci_upper``, ``rmse``, ``r2_score``, ``accuracy``,
        ``mean_percentage_error`` and ``sample_size``. Correlation fields and
        ``r2_score`` are NaN when fewer than two samples are available.
    """
    if not gt_list:
        return {}

    gt_array = np.asarray(gt_list, dtype=float)
    pred_array = np.asarray(pred_list, dtype=float)

    if len(gt_list) >= 2:
        corr, p, lo, hi = pearsonr_ci(gt_array, pred_array)
        r2 = r2_score(gt_array, pred_array)
    else:
        # Correlation and R^2 need at least two samples; pearsonr raises on a
        # single point, which would otherwise abort the whole report.
        corr = p = lo = hi = r2 = float("nan")

    rmse = sqrt(mean_squared_error(gt_array, pred_array))
    # Agreement on which side of the score 5 each value falls (strictly above
    # versus not above).
    acc = np.mean((gt_array > 5) == (pred_array > 5))

    # Percentage error is undefined where the ground truth is zero, so those
    # samples are excluded; the magnitude is used so the result is never
    # negative.
    non_zero_mask = gt_array != 0
    if not np.any(non_zero_mask):
        mean_pe = 0.0
    else:
        mean_pe = np.mean(
            100
            * np.abs(gt_array[non_zero_mask] - pred_array[non_zero_mask])
            / np.abs(gt_array[non_zero_mask])
        )

    return {
        "correlation": corr,
        "correlation_p_value": p,
        "correlation_ci_lower": lo,
        "correlation_ci_upper": hi,
        "rmse": rmse,
        "r2_score": r2,
        "accuracy": acc,
        "mean_percentage_error": mean_pe,
        "sample_size": len(gt_list),
    }


def _format_metrics_report(results: Dict) -> str:
    """Render per-subset metrics and the overall summary as report text."""
    rule = "=" * 80
    lines = [rule, "WEBSITE AESTHETICS EVALUATION METRICS", rule, ""]

    for subset_name, metrics in results.items():
        lines += [
            f"{subset_name.upper()} WEBSITES METRICS:",
            "-" * 40,
            f"Sample Size: {metrics['sample_size']}",
            f"Pearson Correlation: {metrics['correlation']:.4f}",
            f"p-value: {metrics['correlation_p_value']:.6f}",
            f"95% CI: [{metrics['correlation_ci_lower']:.4f}, "
            f"{metrics['correlation_ci_upper']:.4f}]",
            f"RMSE: {metrics['rmse']:.4f}",
            f"R² Score: {metrics['r2_score']:.4f}",
            f"Accuracy (>5): {metrics['accuracy']:.4f}",
            f"Mean Percentage Error: {metrics['mean_percentage_error']:.2f}%",
            "",
        ]

    if "Combined" in results:
        summary = results["Combined"]
        lines += [
            rule,
            "SUMMARY",
            rule,
            "Overall Performance:",
            f"- Total Samples: {summary['sample_size']}",
            f"- Correlation: {summary['correlation']:.4f}",
            f"- RMSE: {summary['rmse']:.4f}",
            f"- Accuracy: {summary['accuracy']:.4f}",
            f"- Mean PE: {summary['mean_percentage_error']:.2f}%",
        ]

    return "\n".join(lines) + "\n"


def calculate_and_write_metrics(json_file_path: str, output_file_path: str):
    """Load data, calculate metrics, and write them to a file.

    The JSON file is iterated as a sequence of records. For each record the
    ground truth is read from ``mean_score`` and the prediction is the mean of
    the valid ``persona_responses`` predictions, falling back to
    ``overall_mean_prediction``. Records missing either value are skipped.
    Records whose ``image`` value contains "english" (case-insensitive) form
    the English subset; all others form the Foreign subset.

    Args:
        json_file_path: Path of the input JSON file.
        output_file_path: Path of the text report to write.
    """
    # Explicit UTF-8: JSON is UTF-8 by specification, and the report contains
    # the non-ASCII "R²", which fails to encode under some locale defaults.
    with open(json_file_path, "r", encoding="utf-8") as f:
        data = json.load(f)

    gt_list_en, pred_list_en = [], []
    gt_list_foreign, pred_list_foreign = [], []

    for d in data:
        gt = d.get("mean_score")
        mean_pred = None

        if "persona_responses" in d:
            persona_predictions = extract_persona_predictions(d["persona_responses"])
            if persona_predictions:
                mean_pred = np.mean(persona_predictions)

        if mean_pred is None:
            mean_pred = d.get("overall_mean_prediction")

        if gt is None or mean_pred is None:
            continue

        # "image" may be present but null (or not a string); treat that the
        # same as a missing value instead of crashing on .lower().
        image_name = str(d.get("image") or "").lower()
        if "english" in image_name:
            gt_list_en.append(gt)
            pred_list_en.append(mean_pred)
        else:
            gt_list_foreign.append(gt)
            pred_list_foreign.append(mean_pred)

    results = {}
    if gt_list_en:
        results["English"] = calculate_subset_metrics(gt_list_en, pred_list_en)
    if gt_list_foreign:
        results["Foreign"] = calculate_subset_metrics(gt_list_foreign, pred_list_foreign)

    gt_combined = gt_list_en + gt_list_foreign
    pred_combined = pred_list_en + pred_list_foreign
    if gt_combined:
        results["Combined"] = calculate_subset_metrics(gt_combined, pred_combined)

    with open(output_file_path, "w", encoding="utf-8") as f:
        f.write(_format_metrics_report(results))

    print(f"\nMetrics calculated successfully and saved to: {output_file_path}")
    if "Combined" in results:
        summary = results["Combined"]
        print("\nSUMMARY:")
        print(f"Total Samples: {summary['sample_size']}")
        print(f"Correlation: {summary['correlation']:.4f}")
        print(f"RMSE: {summary['rmse']:.4f}")
        print(f"Accuracy: {summary['accuracy']:.4f}")
        print(f"Mean PE: {summary['mean_percentage_error']:.2f}%")


def main():
    """Command-line entry point: parse arguments and generate the metrics report.

    Exits with status 1 when the input path does not exist or when any
    exception is raised while processing; in the latter case the traceback is
    printed as well.
    """
    cli = argparse.ArgumentParser(
        description="Calculate website aesthetics evaluation metrics from JSON file."
    )
    for arg_name, arg_help in (
        ("json_file", "Path to the input JSON file."),
        ("output_file", "Path to the output text file."),
    ):
        cli.add_argument(arg_name, help=arg_help)
    options = cli.parse_args()
    input_path = options.json_file

    # Fail early with a short message when the input is missing.
    if not os.path.exists(input_path):
        print(f"Error: Input file '{input_path}' not found.", file=sys.stderr)
        sys.exit(1)

    try:
        bytes_per_mb = 1024 * 1024
        size_mb = os.path.getsize(input_path) / bytes_per_mb
        print(f"Processing file: {input_path} (Size: {size_mb:.2f} MB)")
        calculate_and_write_metrics(input_path, options.output_file)
    except Exception as err:
        # Top-level boundary: report the failure, show the traceback, and
        # signal the error through the exit status.
        print(f"An error occurred: {err}", file=sys.stderr)
        import traceback
        traceback.print_exc()
        sys.exit(1)


# Run the command-line entry point only when this file is executed as a
# script, not when it is imported as a module.
if __name__ == "__main__":
    main()
